/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/**  \file
    Declarations of the grid_base and adaptive_grid classes. This file
    includes all the other grid_xxx.h files.  Never include them
    directly. */


#ifndef __grid__
#define __grid__

#include "config.h"
#include "mcrx-types.h"
#include "mcrx-debug.h"
#include "blitz/numinquire.h"
#include <ostream>
#include <vector> 
#include "ray.h"
#include <algorithm>
#include "mpi_util.h"
#include "hilbert.h"

// Forward declarations of the grid class templates, so that everything
// that follows in this file (including the grid_*.h headers pulled in
// below) can name them before they are defined.
namespace mcrx {
  template <typename> class grid_base;
  template <typename> class adaptive_grid;
  template <typename> class octogrid;
  template <typename, typename> class dynamic_grid;
  template <typename> class grid_cell;
  template <typename> class cell_tracker;
}

// The CCfits types only appear by reference in declarations here, so
// forward declarations are enough (the FITS I/O code is in grid-fits.h).
namespace CCfits {
  class ExtHDU;
  class FITS;
}

#include "grid_cell.h"
#include "cell_tracker.h"
#include "grid_impl.h"

// *** top-level class adaptive_grid ***

/** Class representing the top level of a hierarchical grid
    structure. This is a special class because there are some
    operations that only make sense from the top level. Most grid
    operations are just passed down to the actual grid object.

    The destructor releases the raw memory blocks subgrid_block_ and
    data_block_, so the grid is not copyable: the copy constructor and
    copy assignment are private and intentionally never defined. */
template <typename cell_data_type>
class mcrx::adaptive_grid {
public:
  typedef mcrx::T_float T_float;
  typedef grid_cell<cell_data_type> T_cell;
  typedef typename T_cell::T_grid T_grid;
  typedef cell_tracker<cell_data_type> T_cell_tracker;
  typedef cell_data_type T_data;
  typedef hilbert::qpoint T_qpoint;
  typedef hilbert::cell_code T_code;
  typedef T_cell_tracker iterator;

  friend class cell_tracker<cell_data_type> ;

  /** Placeholder request type. This grid never has to ask other tasks
      to locate a cell (see locate()), so the class carries no data. */
  class T_location_request {
#ifdef HAVE_BOOST_SERIALIZATION
  friend class boost::serialization::access;
  template<class T_arch>
  void serialize(T_arch& ar, const unsigned int version) {};
#endif
  };
  /// Placeholder response type matching T_location_request.
  class T_location_response {
#ifdef HAVE_BOOST_SERIALIZATION
  friend class boost::serialization::access;
  template<class T_arch>
  void serialize(T_arch& ar, const unsigned int version) {};
#endif
  };

private:
  // Copying is not supported. These are declared but intentionally
  // never defined: the destructor frees subgrid_block_ and
  // data_block_, so an implicitly generated copy would free them twice.
  adaptive_grid (const adaptive_grid&);
  adaptive_grid& operator= (const adaptive_grid&);
  // Likewise undefined: constructing from a bare T_grid is not supported.
  adaptive_grid (const T_grid&);

  /// Writes one point of the structure plot as a tab-separated line.
  void structure_plot_point (std::ostream&,const vec3d&) const;

  /** Helper function that creates subgrids according to the supplied
      structure vector. The subgrid objects are all in one contiguous
      block to improve cache performance. It also default-constructs
      data objects for the cells, also in one contiguous block. */
  void create_structure(const std::vector<bool>& structure);

  void load_recursive (std::vector<bool>::const_iterator& refined,
		       const std::vector<bool>::const_iterator& e,
		       void*& placement);

  void skip_suboctree(std::vector<bool>::const_iterator& refined,
		      const std::vector<bool>::const_iterator& e) const;


  /** Determines the domain decomposition. */
  void domain_decomposition(const std::vector<bool>& structure, const int level);
  void calculate_work_recursive(std::vector<bool>::const_iterator& structure,
				const std::vector<bool>::const_iterator& e,
				const hilbert::cell_code code,
				std::vector<int>& work,
				const int decomp_level) const;

protected:

  vec3d min_; ///< Minimum point of the grid cube.
  vec3d max_; ///< Maximum point of the grid cube.

  T_cell c_; ///< The root cell of the octree.

  /// Memory buffer for the allocation of sub grids, if used.
  void* subgrid_block_;

  /// Memory buffer for the allocation of data, if used.
  void* data_block_;

  /// The partitions in the domain decomposition.
  std::vector<hilbert::cell_code> domain_;
  /** The starting cell index of the different domains. This is used
      to figure out which part of the cell data is owned by the
      different tasks. */
  std::vector<size_t> domainstart_;

public:
  /// Creates an adaptive_grid with specified dimensions and no parent cell.
  adaptive_grid (const vec3d & mi, const vec3d & ma):
    min_(mi), max_(ma), c_(mpi_rank()), subgrid_block_ (0), data_block_(0) {};

  /// Creates an adaptive_grid with information from a structure vector.
  adaptive_grid (const vec3d & mi, const vec3d & ma,
		 const std::vector<bool>& structure);
  /// Creates an adaptive_grid with information from a gridstructure HDU.
  adaptive_grid (CCfits::ExtHDU& file, int,
		 const vec3d& =vec3d(0,0,0));
  // creates an adaptive grid with specified base dimensions and uses
  // the grid_factory object to create the grid structure and cell data.
  template <typename T_factory>
  adaptive_grid (const vec3d & mi, const vec3d & ma, T_factory& f);
  ~adaptive_grid ();

  /// Converts a 3D position into the quantized integer coordinates
  /// used by the octree.
  T_qpoint qpoint(vec3d p) const { 
    assert_in_grid(p);
    return T_qpoint((p-getmin())/getsize()); };

  /** Converts a quantized integer coordinate used by the octree into
      a 3D position. Practically, the position returned is always the
      min corner of the cell, so for example the 0-level position
      returns (0,0,0). It is NOT the center of the cell. */
  vec3d real_point(T_qpoint qp) const { 
    return vec3d(qp)*getsize()+getmin(); };

  /// Single-component version of real_point(T_qpoint).
  T_float real_point(T_qpoint qp, int dim) const { 
    return hilbert::qpoint::unquantize(qp[dim], qp.level())*
      (max_[dim]-min_[dim])+min_[dim]; };

  // simple data retrieval functions
  const vec3d& getmin() const {return min_; };
  const vec3d& getmax() const {return max_; };

  vec3d getsize() const {return max_-min_; };
  T_float getsize(int dim) const {return max_[dim]-min_[dim]; };

  /// Returns true if the point is within the grid (boundary included).
  bool in_grid(const vec3d& p) const {
    return all(p>=getmin()) && all(p<=getmax()); };

  /** Asserts, in debug builds only, that p lies inside the grid. A
      tolerance of 1e-10 times the grid size allows for round-off. */
  void assert_in_grid(const vec3d& p) const {
#ifndef NDEBUG
    assert(all(p>=getmin()-getsize()*1e-10));
    assert(all(p<=getmax()+getsize()*1e-10));
#endif
  };

  /// Never used for this grid (see locate()); always throws.
  T_location_response location_request(const T_location_request&) const {
    throw 1; return T_location_response();};

  /** Initializes a tracker for rays that have been shipped over onto
      this task. We only know the cell code, so we must do a locate()
      on it. */
  void initialize_tracker(const T_code& code, const vec3d& pos, 
			  T_cell_tracker& c) {
    c.restart(this);
    c.locate(code);
    if(!c.is_leaf()) {
      // we have higher resolution than the qpoint in the tracker, so
      // now we need to figure out which subcell the position is in.
      c.locate(qpoint(pos), true);
    }
    assert(c.task()==mpi_rank());
  };

  /** Locate the cell containing a position. The int argument is
      unused, it's for distinguishing which thread makes a request
      when a location request has to be made to other tasks, but we
      never have to do that. Note that here we handle accept_outside
      by truncating the position to the grid extent. Since we start
      the search at top level, we no longer have to worry about it
      being outside after that. If accept_outside is false and p is
      outside the grid, a default-constructed tracker is returned. */
  T_cell_tracker locate (vec3d p, int, bool accept_outside) {
    if(accept_outside) 
      p=truncate_to_box(p,getmin(), getmax());
    else if(!in_grid(p)) return T_cell_tracker();
    if(c_.is_leaf())
      return T_cell_tracker(this);
    else {
      const T_qpoint qp(qpoint(p));
      T_cell_tracker t(this);
      t.locate(qp, false);
      return t;
    }
  };

  /** Finds the cell where a ray enters, or already is in, the grid,
      and returns it together with the value l from intersect_cube. A
      non-NaN l (l==l) is treated as the ray starting outside and
      entering at distance l; a NaN l is treated as the ray starting
      inside, in which case the ray's own position is located. */
  std::pair<T_cell_tracker, T_float> 
  intersection_from_without (const ray_base& ray, int thread) {
    const T_float l = ray.intersect_cube(getmin(), getmax(), false);
    if(l==l)
      // outside
      return std::make_pair(locate(ray.position() + l*ray.direction(), 
				   thread, true), l); 
    else
      // inside
      return std::make_pair(locate(ray.position(), thread, false), l);
  };

  std::vector<bool> position_ownership(std::vector<vec3d>& pos) const;

  // outputs grid structure to a stream
  void structure_plot (std::ostream&) const;
  /** Returns a structure vector for the grid (one entry per cell in
      depth-first order, true for refined cells), optionally including
      vectors of the cell codes and quantized positions. */
  std::vector<bool> get_structure (std::vector<T_code>* codes=0,
				   std::vector<T_qpoint::T_pos>* qpos=0) const {
    std::vector<bool> structure;
    T_cell_tracker c(const_cast<adaptive_grid*>(this));
    while(!c.at_end()) {
      structure.push_back(!c->is_leaf());
      if(codes)
	codes->push_back(c.code());
      if(qpos)
	qpos->push_back(c.qpos().pos());
      c.dfs();
    }
    return structure;
  };

  /// saves grid structure to a FITS HDU (include grid-fits.h)
  void save_structure (CCfits::FITS& file, const std::string& hdu,
		       const std::string& length_unit = "",
		       bool save_codes=false) const;
  /// Returns the number of cells in the grid.
  int n_cells () const;
  /// Returns the number in the grid enumeration of the specified cell.
  int get_cell_number(const T_cell_tracker& c) const;

  /** Returns statistics about the number of cells at different
      refinement levels in the grid. */
  std::vector<int> statistics () const;

  iterator begin () {T_cell_tracker t(this); if(!t.is_leaf()) ++t; return t; };
  iterator end () {T_cell_tracker t(this); t.reset(); return t; };

  class const_iterator;
  const_iterator begin () const {
    const_iterator t(this); if(!t.is_leaf()) ++t; return t; };
  const_iterator end () const { return const_iterator(); };

  T_cell_tracker root() { return T_cell_tracker(this); };
  bool check_integrity() const;

  /// Returns the start and number of the cells which are in our domain.
  std::pair<size_t, size_t> domain_index() const { 
    assert(domainstart_.size()==static_cast<size_t>(mpi_size())+1);
    return std::make_pair(domainstart_[mpi_rank()], 
			  domainstart_[mpi_rank()+1] - 
			  domainstart_[mpi_rank()]); };

  bool shared_domain() const { return false; };
};



// *** adaptive_grid methods ***

/** Constructor that creates the grid from a structure vector. The
    structure vector has one entry per cell in depth-first order, true
    for a refined cell and false for a leaf (the format produced by
    get_structure()). All the work is done by create_structure(),
    which allocates one memory block holding all the subgrids, so that
    they are well localized in memory, calls load_recursive to build
    the hierarchy, and then allocates the cell data objects.

    The root cell is constructed with mpi_rank(), the same as in the
    (min, max) constructor. Both memory-block pointers start out null,
    so the destructor only frees what create_structure() allocated. */
template < typename cell_data_type>
mcrx::adaptive_grid<cell_data_type>::
  adaptive_grid (const vec3d & mi, const vec3d & ma,
		 const std::vector<bool>& structure) :
    min_(mi), max_(ma), c_(mpi_rank()), subgrid_block_ (0), data_block_ (0)
{
  create_structure(structure);
}


/** Destroys the grid. The cell data objects and the subgrids were
    created with placement new inside data_block_ and subgrid_block_,
    so their destructors are run explicitly before the raw blocks are
    released with operator delete. */
template < typename cell_data_type>
mcrx::adaptive_grid<cell_data_type>::~adaptive_grid ()
{
  // Cell data first: destroy each object in place, detach it from its
  // cell, and only then release the block that held them.
  if (data_block_ != 0) {
    iterator cell = begin ();
    iterator last = end ();
    for (; cell != last; ++cell) {
      cell->data()->~T_data();
      cell->unset_data();
    }
    operator delete (data_block_);
  }

  // Then the subgrids: walk the hierarchy running the grid_cell
  // destructors (without deleting anything), then free the block.
  if (subgrid_block_ != 0) {
    c_.placement_new_destructor();
    operator delete (subgrid_block_);
  }
}

/** Builds the subgrid hierarchy described by a depth-first structure
    vector (true = refined cell) and attaches a default-constructed
    data object to every cell visited by the iterator. Subgrids and
    data objects each live in a single contiguous block allocated with
    operator new and filled with placement new; the destructor undoes
    both. Progress messages about the block sizes go to std::cout. */
template < typename cell_data_type>
void mcrx::adaptive_grid<cell_data_type>::
create_structure(const std::vector<bool>& structure)
{   
  // on a split domain, how do we figure out how many subgrids we own?
  // Maybe easiest to just create the grids and then blockify them?
  /// \todo we are still allocating enough storage for the full domain

  // Figure out how many sub grids we have (we have one for every
  // true in the structure vector)
  const size_t n_subgrids =
    std::count (structure.begin(), structure.end(), true);
  const size_t block_size = n_subgrids*sizeof(T_grid); 
  std::cout << "Allocating a memory block for " << n_subgrids
	    << " subgrids, " << block_size*1.0/(1024*1024) << " MB."  << std::endl;
  subgrid_block_ = operator new (block_size);
  
  std::vector<bool>::const_iterator si = structure.begin();
  void* sb = subgrid_block_;

  // load_recursive takes si and sb by reference and advances them as
  // it consumes structure entries and places subgrids in the block.
  load_recursive (si, structure.end(), sb);
  assert(si==structure.end()); // make sure we used exactly all elements
  assert(static_cast<char*>(sb)<=
	 static_cast<char*>(subgrid_block_)+block_size);

  // Now allocate the data objects in one contiguous chunk
  const int nc = n_cells ();
  const size_t data_block_size = nc*sizeof (T_data);
  std::cout << "Allocating a memory block for " << nc 
	    << " cell data objects, " << data_block_size*1.0/(1024*1024) 
	    << " MB."  << std::endl;
  data_block_ = operator new (data_block_size);

  // Construct one data object per cell, in iteration order, directly
  // in the block (placement new) and hand it to the cell.
  T_data* next_data = static_cast<T_data*> (data_block_);
  for (iterator c = begin (), e=end(); c != e; ++c)
    c->set_data(new (next_data++) T_data ());
  // The iteration must visit exactly the nc cells counted above,
  // otherwise we wrote outside (or under-filled) the block.
  assert(next_data == static_cast<T_data*> (data_block_) + nc);
}

/** Returns the number of cells visited when iterating from begin() to
    end(). The grid is walked on every call, so the cost is linear in
    the number of cells. */
template <typename cell_data_type>
int mcrx::adaptive_grid<cell_data_type>::n_cells () const
{
  int count = 0;
  for (const_iterator it = begin(), stop = end(); it != stop; ++it)
    ++count;
  return count;
}

/** Histogram of cells by refinement level: element k of the returned
    vector is the number of iterated cells whose code has level k. The
    vector is only as long as the deepest level that occurs. */
template <typename cell_data_type>
std::vector<int>
mcrx::adaptive_grid<cell_data_type>::statistics () const
{
  std::vector<int> per_level;
  const_iterator it = begin();
  const const_iterator stop = end();
  while (it != stop) {
    const size_t level = it.code().level();
    // grow on demand; new levels start at zero
    if (level >= per_level.size())
      per_level.resize(level + 1, 0);
    per_level[level] += 1;
    ++it;
  }
  return per_level;
}

/** Intended to return the index of cell c in the grid enumeration.
    NOT IMPLEMENTED: this always throws the int 1. A disabled sketch
    compared cell addresses while walking the const_iterator range:

      int n = 0;
      for (const_iterator i = begin (), e = end (); i != e; ++i, ++n)
        if (&(*i)==c.cell())
          return n;
*/
template <typename cell_data_type>
int mcrx::adaptive_grid<cell_data_type>::
get_cell_number (const T_cell_tracker& /* c */) const
{
  throw 1;
}

/** Writes one point of the structure plot: the three coordinates of p
    separated by tabs, terminated by a newline. */
template <typename cell_data_type>
void
mcrx::adaptive_grid<cell_data_type>::
structure_plot_point (std::ostream& o, const vec3d& p) const
{
  o << p[0] << '\t' << p[1] << '\t' << p[2] << '\n';
}


/** Saves the grid structure as a text file to the specified
    stream. The structure consists of a bunch of line segments that
    you can feed to e.g. gnuplot to plot the grid. For every iterated
    cell the output is one polyline tracing the face at the minimum
    third coordinate and then the face at the maximum third coordinate,
    followed by the three remaining edges between those faces as
    separate segments; a blank line ends each polyline/segment. The
    stream precision is set to 12 and left that way. */
template <typename cell_data_type>
void 
mcrx::adaptive_grid<cell_data_type>::
structure_plot (std::ostream& o) const
{
  // Indices into corner[] below: the order in which the polyline
  // visits the corners, and the three edges it does not cover.
  static const int outline[10] = { 0, 1, 2, 3, 0, 4, 5, 6, 7, 4 };
  static const int edge[3][2] = { {1, 5}, {3, 7}, {2, 6} };

  o.precision (12); 
  for (const_iterator cell = begin(), last = end(); cell != last; ++cell) {
    const vec3d lo = cell.getmin();
    const vec3d hi = cell.getmax();
    // corners 0-3 have third coordinate lo[2], corners 4-7 have hi[2]
    const vec3d corner[8] = {
      lo,
      vec3d (hi [0], lo [1], lo [2]),
      vec3d (hi [0], hi [1], lo [2]),
      vec3d (lo [0], hi [1], lo [2]),
      vec3d (lo [0], lo [1], hi [2]),
      vec3d (hi [0], lo [1], hi [2]),
      vec3d (hi [0], hi [1], hi [2]),
      vec3d (lo [0], hi [1], hi [2])
    };

    for (int k = 0; k < 10; ++k)
      structure_plot_point (o, corner[outline[k]]);
    o << "\n";

    for (int k = 0; k < 3; ++k) {
      structure_plot_point (o, corner[edge[k][0]]);
      structure_plot_point (o, corner[edge[k][1]]);
      o << "\n";
    }
  }
}

/** Checks that all leaf cells have a data object attached, printing a
    message to std::cerr if any is missing. Note the return convention:
    true means the check FAILED (some cell has no data), false means
    every iterated cell has its data. */
template <typename cell_data_type>
bool mcrx::adaptive_grid<cell_data_type>::check_integrity () const
{
  bool missing_data = false;
  for (const_iterator it = begin (), stop = end (); it != stop; ++it)
    if (!it->data())
      missing_data = true;

  if (missing_data)
    std::cerr << "Grid integrity check failed!" << std::endl;

  return missing_data;
}

/** Checks whether we own a bunch of positions. Returns one flag per
    entry of pos: true if this task owns that position. With a single
    task every position is owned. Otherwise a position inside the grid
    is Hilbert-encoded, at the level of the last domain boundary, and
    owned if its code lies in [domain_[rank], domain_[rank+1]).

    \todo We can only encode positions inside the grid. What do we do
    about positions outside? For now we assign those to task 0. */
template <typename cell_data_type>
std::vector<bool>
mcrx::adaptive_grid<cell_data_type>::
position_ownership(std::vector<vec3d>& pos) const 
{
  using namespace hilbert;

  if(mpi_size()==1) 
    return std::vector<bool>(pos.size(), true);
  const int this_task = mpi_rank();
  // domain_[this_task+1] is read below, so the decomposition must
  // already hold one boundary per task plus the closing boundary.
  assert(domain_.size() >= static_cast<size_t>(mpi_size())+1);

  // loop-invariant normalization extent of the grid
  const vec3d extent = getmax()-getmin();

  std::vector<bool> own;
  own.reserve(pos.size());
  for(size_t i=0; i<pos.size(); ++i) {
    if(!in_grid(pos[i])) {
      // cannot be encoded; see \todo above
      own.push_back(this_task==0);
      continue;
    }
    const cell_code h(encode_point(vec3d((pos[i]-getmin())/extent), 
				   domain_.back().n_));
    own.push_back(domain_[this_task]<=h && h<domain_[this_task+1]);
  }
  return own;
}

/** The const_iterator is just a wrapper around a normal iterator (a
    cell_tracker) that only hands out const access to the cells. All
    query functions forward to the wrapped tracker. */
template<typename T>
class mcrx::adaptive_grid<T>::const_iterator :
  public std::iterator<std::forward_iterator_tag, 
		       typename mcrx::grid_cell<T> >  {
private:
  typedef typename adaptive_grid<T>::iterator T_iterator_impl;
  typedef adaptive_grid<T> T_topgrid;

  /// The wrapped tracker; only const views of it are exposed.
  T_iterator_impl i_;

public:
  /// Default-constructed iterator; adaptive_grid::end() const uses it.
  const_iterator() : i_() {};

  /** Iterator starting at the root of grid g. The tracker is built
      from a non-const grid pointer, hence the const_cast; this class
      itself never gives out mutable access. */
  const_iterator(const T_topgrid* g) : 
    i_(const_cast<T_topgrid*>(g)) {};
  /// Wraps an existing tracker.
  const_iterator(const T_iterator_impl i) : i_(i) {};

  bool operator==(const const_iterator& rhs) const {return i_==rhs.i_; };
  bool operator!=(const const_iterator& rhs) const {return i_!=rhs.i_; };

  operator bool() const { return bool(i_); };

  const T_cell& operator*() const { return *i_.cell(); };
  const T_cell* operator->() const { return i_.cell(); };
  const_iterator& operator++ () { // prefix
    ++i_; return *this; };
  const_iterator operator++ (int){ // postfix
    const_iterator i (*this); operator++ (); return i;};

  typename T_iterator_impl::T_code code() const { return i_.code(); };
  /// Quantized position of the current cell (forwards to the tracker's qpos()).
  typename T_iterator_impl::T_qpoint qpos() const { return i_.qpos(); };
  int task() const { return i_.task(); };
  bool is_leaf() const { return i_.is_leaf(); };
  vec3d getsize() const { return i_.getsize(); };
  vec3d getinvsize() const { return i_.getinvsize(); };
  vec3d getmin() const { return i_.getmin(); };
  vec3d getmax() const { return i_.getmax(); };
  vec3d getcenter() const { return i_.getcenter(); };
  bool contains(const vec3d& p) const { return i_.contains(p); };
  void assert_contains(const vec3d& p) const { i_.assert_contains(p); };
  T_float volume() const { return i_.volume(); };
  typename T_cell::T_data* data() const { return i_.data(); };

  /// Prints "Const " followed by the wrapped tracker.
  std::ostream& print(std::ostream& os) const {
    os << "Const " << i_;
    return os;
  }
};

#endif
